/**
 * Disjoint-set forest over the integers 0..n-1, using path compression in
 * `find` and union by size in `union`.
 */
class UnionFind {
    /** Parent pointer for each element; a root is its own parent. */
    p: number[];
    /** Component size, kept up to date only at root indices. */
    size: number[];

    constructor(n: number) {
        this.p = [...Array(n).keys()];
        this.size = Array<number>(n).fill(1);
    }

    /**
     * Returns the root of `x`'s set. Every node on the walked path is
     * re-pointed directly at that root.
     */
    find(x: number): number {
        // Pass 1: climb to the root.
        let root = x;
        while (this.p[root] !== root) {
            root = this.p[root];
        }
        // Pass 2: compress the path behind us.
        let node = x;
        while (this.p[node] !== root) {
            const parent = this.p[node];
            this.p[node] = root;
            node = parent;
        }
        return root;
    }

    /**
     * Merges the sets containing `a` and `b`.
     * @returns `true` if a merge happened, `false` if they were already joined.
     */
    union(a: number, b: number): boolean {
        const rootA = this.find(a);
        const rootB = this.find(b);
        if (rootA === rootB) return false;
        // The smaller tree hangs under the larger one; on a tie, rootA goes under rootB.
        const [big, small] =
            this.size[rootA] > this.size[rootB] ? [rootA, rootB] : [rootB, rootA];
        this.p[small] = big;
        this.size[big] += this.size[small];
        return true;
    }

    /**
     * Size of the component whose root is `root`. Pass a value returned by
     * `find`: non-root entries hold stale sizes.
     */
    getSize(root: number): number {
        return this.size[root];
    }
}

/**
 * Picks the initially infected node whose complete removal (the node and all
 * of its edges) leaves the fewest nodes infected once malware has spread.
 *
 * Approach: union-find over the clean (non-initial) nodes only. Removing
 * infected node `i` saves a clean component exactly when `i` is the only
 * infected node adjacent to that component. The score of `i` is the total
 * size of the components it alone touches. The highest score wins; ties go
 * to the smallest index.
 *
 * @param graph - adjacency matrix; a truthy `graph[i][j]` means i and j are
 *   connected. Assumed symmetric: the clean-node union step reads only j > i.
 * @param initial - initially infected node indices. Duplicate entries are
 *   ignored.
 * @returns the node to remove, or 0 when `initial` is empty.
 */
function minMalwareSpread(graph: number[][], initial: number[]): number {
    const n = graph.length;
    const s = new Set(initial);
    const uf = new UnionFind(n);

    // Link clean nodes to each other; infected nodes stay out of every component.
    for (let i = 0; i < n; ++i) {
        if (s.has(i)) {
            continue;
        }
        for (let j = i + 1; j < n; ++j) {
            if (graph[i][j] && !s.has(j)) {
                uf.union(i, j);
            }
        }
    }

    // For each distinct infected node: the roots of the clean components it
    // touches. For each root: how many distinct infected nodes touch it.
    // Iterate the Set, not `initial`, so a duplicated entry is not counted twice
    // (double counting would make a uniquely-owned component look shared).
    const g = new Map<number, Set<number>>();
    const cnt: number[] = Array(n).fill(0);
    for (const i of s) {
        const roots = new Set<number>();
        for (let j = 0; j < n; ++j) {
            if (graph[i][j] && !s.has(j)) {
                roots.add(uf.find(j));
            }
        }
        for (const root of roots) {
            ++cnt[root];
        }
        g.set(i, roots);
    }

    let ans = 0;
    let mx = -1; // below any real score, so the first candidate is always taken
    for (const [i, roots] of g) {
        let t = 0;
        for (const root of roots) {
            if (cnt[root] === 1) {
                t += uf.getSize(root);
            }
        }
        if (t > mx || (t === mx && i < ans)) {
            ans = i;
            mx = t;
        }
    }
    return ans;
}
